import copy
import inspect
from typing import List, Union

import torch
import torch.nn as nn
import lightning

from mmengine.config import Config, ConfigDict
from mmengine.device import is_npu_available
from mmpl.registry import LOGGERS


def register_pl_loggers() -> List[str]:
    """Register loggers in ``lightning.pytorch.loggers`` to the ``LOGGERS`` registry.

    Every non-dunder attribute of ``lightning.pytorch.loggers`` that is a
    class deriving from ``lightning.pytorch.loggers.logger.Logger`` is
    registered into ``LOGGERS``. Because ``issubclass`` is reflexive, the
    ``Logger`` base class itself is also registered if the package exports it.

    Returns:
        List[str]: A list of registered loggers' names, i.e. the attribute
        names under which the classes were found in
        ``lightning.pytorch.loggers``.
    """
    loggers_module = lightning.pytorch.loggers
    logger_base = loggers_module.logger.Logger
    pl_loggers = []
    for module_name in dir(loggers_module):
        # Skip dunder attributes such as ``__all__`` or ``__name__``.
        if module_name.startswith('__'):
            continue
        _logger = getattr(loggers_module, module_name)
        # ``dir`` also yields submodules and functions; keep Logger classes only.
        if inspect.isclass(_logger) and issubclass(_logger, logger_base):
            LOGGERS.register_module(module=_logger)
            pl_loggers.append(module_name)
    return pl_loggers


# Side effect at import time: every Lightning logger class found by
# ``register_pl_loggers`` is registered into ``LOGGERS``; the list of their
# names is kept here.
PL_LOGGERS = register_pl_loggers()


def register_dadaptation_optimizers() -> List[str]:
    """Register optimizers in ``dadaptation`` to the ``OPTIMIZERS`` registry.

    ``dadaptation`` is an optional dependency. If it cannot be imported,
    nothing is registered and an empty list is returned. Candidate names
    that are missing from the installed ``dadaptation`` version, or that are
    not ``torch.optim.Optimizer`` subclasses, are skipped.

    Returns:
        List[str]: A list of registered optimizers' names.
    """
    dadaptation_optimizers = []
    try:
        import dadaptation
    except ImportError:
        return dadaptation_optimizers

    # ``OPTIMIZERS`` is not imported at module level in this file, so the
    # original code raised NameError here. Resolve it lazily: prefer the
    # project registry (the one ``LOGGERS`` comes from) and fall back to
    # mmengine's registry if the project one does not provide it.
    try:
        from mmpl.registry import OPTIMIZERS
    except ImportError:
        from mmengine.registry import OPTIMIZERS

    for module_name in ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD']:
        # ``None`` default: tolerate dadaptation versions lacking a name.
        _optim = getattr(dadaptation, module_name, None)
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            OPTIMIZERS.register_module(module=_optim)
            dadaptation_optimizers.append(module_name)
    return dadaptation_optimizers


# DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers()


def register_lion_optimizers() -> List[str]:
    """Register Lion optimizer to the ``OPTIMIZERS`` registry.

    ``lion_pytorch`` is an optional dependency. If it cannot be imported,
    nothing is registered and an empty list is returned.

    Returns:
        List[str]: A list of registered optimizers' names (``['Lion']`` when
        ``lion_pytorch`` is available).
    """
    optimizers = []
    try:
        from lion_pytorch import Lion
    except ImportError:
        return optimizers

    # ``OPTIMIZERS`` is not imported at module level in this file, so the
    # original code raised NameError here. Resolve it lazily: prefer the
    # project registry (the one ``LOGGERS`` comes from) and fall back to
    # mmengine's registry if the project one does not provide it.
    try:
        from mmpl.registry import OPTIMIZERS
    except ImportError:
        from mmengine.registry import OPTIMIZERS

    OPTIMIZERS.register_module(module=Lion)
    optimizers.append('Lion')
    return optimizers


# LION_OPTIMIZERS = register_lion_optimizers()


def build_optim_wrapper(model: nn.Module,
                        cfg: Union[dict, Config, ConfigDict]):
    """Build function of OptimWrapper.

    If ``constructor`` is set in the ``cfg``, this method will build an
    optimizer wrapper constructor, and use optimizer wrapper constructor to
    build the optimizer wrapper. If ``constructor`` is not set, the
    ``DefaultOptimWrapperConstructor`` will be used by default.

    Args:
        model (nn.Module): Model to be optimized.
        cfg (dict | Config | ConfigDict): Config of optimizer wrapper,
            optimizer constructor and optimizer. The optional keys
            ``constructor`` and ``paramwise_cfg`` are consumed here; the
            remaining keys are passed to the constructor as
            ``optim_wrapper_cfg``. ``cfg`` itself is not modified.

    Returns:
        OptimWrapper: The built optimizer wrapper.
    """
    # ``OPTIM_WRAPPER_CONSTRUCTORS`` is not imported at module level in this
    # file, so the original code raised NameError on every call. Resolve it
    # lazily: prefer the project registry (the one ``LOGGERS`` comes from)
    # and fall back to mmengine's registry if the project one lacks it.
    try:
        from mmpl.registry import OPTIM_WRAPPER_CONSTRUCTORS
    except ImportError:
        from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS

    # Work on a deep copy so the pops below never mutate the caller's config.
    optim_wrapper_cfg = copy.deepcopy(cfg)
    constructor_type = optim_wrapper_cfg.pop('constructor',
                                             'DefaultOptimWrapperConstructor')
    paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)

    # Since the current generation of NPU(Ascend 910) only supports
    # mixed precision training, here we turn on mixed precision by default
    # on the NPU to make the training normal
    if is_npu_available():
        optim_wrapper_cfg['type'] = 'AmpOptimWrapper'

    optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
        dict(
            type=constructor_type,
            optim_wrapper_cfg=optim_wrapper_cfg,
            paramwise_cfg=paramwise_cfg))
    optim_wrapper = optim_wrapper_constructor(model)
    return optim_wrapper
